InĀ [49]:
import random
import process_data_set
import importlib
import os
import cv2
import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np
import time
import os
import cv2
import pandas as pd
from concurrent.futures import ThreadPoolExecutor, as_completed
from tqdm import tqdm
import time
from skimage.measure import shannon_entropy
from ml_tools_utils import utils
InĀ [55]:
utils.pandas_config(pd)
utils.plt_config(plt)

sns.set_theme(style="darkgrid", palette="pastel")
plt.style.use("fivethirtyeight")

Dataset processing and analysis¶

Importing and cleaning up the dataset:¶

The dataset included duplicate versions of some image files which need to be removed (checked using imagehash)

InĀ [52]:
importlib.reload(process_data_set)

process_data_set.download_ds(process_data_set.TEMP_DATASET_NAME)
InĀ [53]:
dups = process_data_set.find_duplicates(process_data_set.UNPROC_DATASET_LOC)
dups
Out[53]:
Class Duplicate Count Total Images Proportion
0 Agaricus 2 353 0.005666
1 Amanita 2 750 0.002667
2 Boletus 2 1073 0.001864
3 Cortinarius 2 836 0.002392
4 Entoloma 0 364 0.000000
5 Hygrocybe 1 316 0.003165
6 Lactarius 63 1563 0.040307
7 Russula 4 1147 0.003487
8 Suillus 0 311 0.000000
9 Total 76 6713 0.011321

There are some corrupt and unreadable images in the dataset that also need to be removed:

InĀ [54]:
process_data_set.verify_and_clean_images(process_data_set.UNPROC_DATASET_LOC)
Removing corrupt image: dataset_temp\Mushrooms\Russula\092_43B354vYxm8.jpg due to image file is truncated (92 bytes not processed)
Out[54]:
[WindowsPath('dataset_temp/Mushrooms/Russula/092_43B354vYxm8.jpg')]

Image Analysis¶

InĀ [56]:
def get_image_paths(data_dir):
    """Recursively collect paths of all image files under ``data_dir``.

    Args:
        data_dir: Root directory to walk.

    Returns:
        List of full paths to files with a .png/.jpg/.jpeg extension.
        The match is case-insensitive, so '.JPG', '.PNG', etc. are
        included as well (the previous exact-case check silently skipped
        them).
    """
    image_paths = []
    for subdir, _, files in os.walk(data_dir):
        for file in files:
            # lower() makes the extension check case-insensitive; image
            # datasets frequently mix '.jpg' and '.JPG' file names.
            if file.lower().endswith(('.png', '.jpg', '.jpeg')):
                image_paths.append(os.path.join(subdir, file))
    return image_paths


def process_image(image_path, bins=32):
    """Load one image and compute its color statistics.

    Args:
        image_path: Path to the image file; the parent directory name is
            used as the class label.
        bins: Number of histogram bins per color channel.

    Returns:
        Tuple of (color_distributions, color_type, original_shape,
        aspect_ratio, variance, unique_colors, entropy, image_path),
        or None when the file cannot be decoded.
    """
    image = cv2.imread(image_path)
    if image is None:
        # Unreadable/corrupt file — caller is expected to skip None results.
        return None

    # NOTE(review): cv2.imread with default flags decodes to a 3-channel BGR
    # array, so the 'Grayscale'/'Other' branches appear defensive and are
    # unlikely to trigger — confirm if read flags ever change.
    color_type = 'Unknown'
    if len(image.shape) == 2:
        color_type = 'Grayscale'
        image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
    elif len(image.shape) == 3:
        if image.shape[2] == 3:
            color_type = 'Color'
        else:
            color_type = f'Other ({image.shape[2]} channels)'

    # Convert BGR -> RGB so channel 0/1/2 line up with the 'R'/'G'/'B' keys.
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    original_shape = image.shape
    aspect_ratio = original_shape[1] / original_shape[0]  # width / height
    image = cv2.resize(image, (100, 100))  # downsample to keep the stats cheap

    # Per-channel histogram, normalized so each channel's bins sum to 1;
    # keyed by class name for later merging across images.
    class_name = os.path.basename(os.path.dirname(image_path))
    color_distributions = {class_name: {'R': [], 'G': [], 'B': []}}
    for channel, color in enumerate(['R', 'G', 'B']):
        hist = cv2.calcHist([image], [channel], None, [bins], [0, 256])
        hist = hist.flatten() / hist.sum()
        color_distributions[class_name][color] = hist

    # Scalar summaries: mean of the per-channel variances, count of distinct
    # RGB triples in the downsampled image, and Shannon entropy.
    variance = np.var(image, axis=(0, 1)).mean()
    unique_colors = len(np.unique(image.reshape(-1, image.shape[2]), axis=0))
    entropy = shannon_entropy(image)

    return color_distributions, color_type, original_shape, aspect_ratio, variance, unique_colors, entropy, image_path


def merge_color_distributions(distributions_list):
    """Average the per-image color histograms into one histogram per class.

    Args:
        distributions_list: Iterable of process_image() result tuples
            (None entries are skipped); the first element of each tuple
            maps class name -> {'R': hist, 'G': hist, 'B': hist}.

    Returns:
        Dict mapping each class name to its averaged R/G/B histograms.
    """
    totals = {}
    counts = {}

    for entry in distributions_list:
        if entry is None:
            continue
        per_class = entry[0]
        for cls, channels in per_class.items():
            if cls not in totals:
                # First time this class is seen: start zeroed accumulators
                # shaped like the incoming histograms.
                totals[cls] = {c: np.zeros_like(channels[c]) for c in 'RGB'}
                counts[cls] = 0
            for c in 'RGB':
                totals[cls][c] += channels[c]
            counts[cls] += 1

    # Turn the per-class sums into means.
    for cls, channels in totals.items():
        for c in 'RGB':
            channels[c] /= counts[cls]

    return totals


def get_color_distributions(image_paths, max_workers=None):
    """Process all images concurrently and gather their color metrics.

    Args:
        image_paths: List of image file paths to analyze.
        max_workers: Thread-pool size; defaults to os.cpu_count().

    Returns:
        Tuple of (merged color distributions, color types, shapes,
        aspect ratios, variances, unique color counts, entropies,
        processed image paths); the per-image lists are aligned with
        each other in completion order.
    """
    start_time = time.time()

    workers = os.cpu_count() if max_workers is None else max_workers
    print(f"Running on {workers} workers")

    # Collect the full result tuples in completion order; None results
    # (unreadable images) are dropped.
    results = []
    with ThreadPoolExecutor(max_workers=workers) as pool:
        pending = {pool.submit(process_image, path): path for path in image_paths}
        for done in tqdm(as_completed(pending), total=len(pending), desc="Processing images"):
            outcome = done.result()
            if outcome is not None:
                results.append(outcome)

    merged = merge_color_distributions(results)
    elapsed = time.time() - start_time
    print(f"Total processing time: {elapsed:.2f} seconds")

    # Split the tuples into aligned per-metric lists.
    color_types = [r[1] for r in results]
    shapes = [r[2] for r in results]
    aspect_ratios = [r[3] for r in results]
    variances = [r[4] for r in results]
    unique_colors = [r[5] for r in results]
    entropies = [r[6] for r in results]
    processed_paths = [r[7] for r in results]

    return merged, color_types, shapes, aspect_ratios, variances, unique_colors, entropies, processed_paths


def summarize_image_types(image_paths, color_types):
    """Tabulate how many images fall into each color-type category.

    Args:
        image_paths: Image paths (aligned with color_types).
        color_types: Per-image color-type label ('Color', 'Grayscale', ...).

    Returns:
        DataFrame with columns ['Color Type', 'Count'], most common first.
    """
    frame = pd.DataFrame({'image_path': image_paths, 'color_type': color_types})
    counts = frame['color_type'].value_counts().reset_index()
    counts.columns = ['Color Type', 'Count']
    return counts


def summarize_dimensions(image_paths, shapes):
    """Build frequency tables of image widths, heights and aspect ratios.

    Args:
        image_paths: Image paths (aligned with shapes).
        shapes: Per-image shape tuples (height, width[, channels]).

    Returns:
        Three DataFrames: width counts, height counts, aspect-ratio counts,
        each with a value column and a 'Count' column.
    """
    frame = pd.DataFrame({'image_path': image_paths, 'shape': shapes})

    def counted(series, label):
        # Tally distinct values and name the resulting columns.
        table = series.value_counts().reset_index()
        table.columns = [label, 'Count']
        return table

    widths = counted(frame['shape'].apply(lambda s: s[1]), 'Width')
    heights = counted(frame['shape'].apply(lambda s: s[0]), 'Height')
    ratios = counted(frame['shape'].apply(lambda s: s[1] / s[0]), 'Aspect Ratio')
    return widths, heights, ratios


def summarize_color_metrics(image_paths, variances, unique_colors, entropies):
    """Collect the per-image color metrics into one tidy DataFrame.

    Args:
        image_paths: Image paths (one row per image).
        variances: Mean per-channel color variance per image.
        unique_colors: Count of distinct RGB colors per image.
        entropies: Shannon entropy per image.

    Returns:
        DataFrame with columns image_path, variance, unique_colors, entropy.
    """
    columns = {
        'image_path': image_paths,
        'variance': variances,
        'unique_colors': unique_colors,
        'entropy': entropies,
    }
    return pd.DataFrame(columns)


def plot_color_distributions(color_distributions, bins=32):
    """Plot the averaged R/G/B intensity histograms for every class.

    Args:
        color_distributions: Dict mapping class name to averaged
            {'R': hist, 'G': hist, 'B': hist} histograms (see
            merge_color_distributions).
        bins: Number of histogram bins; must match the histogram length.

    All subplots share a common y-limit (the largest density found across
    every class and channel) so the plots are directly comparable.
    """
    # One pass to find the global peak density; the previous version also
    # tracked an unused `global_min`. The trailing [0] preserves the old
    # behavior of never going below zero.
    peaks = [max(dist[color])
             for dist in color_distributions.values()
             for color in ['R', 'G', 'B']]
    global_max = max(peaks + [0])

    for class_name, distributions in color_distributions.items():
        fig, axes = plt.subplots(1, 3, figsize=(18, 6))
        for color, ax in zip(['R', 'G', 'B'], axes):
            ax.bar(range(bins), distributions[color], color=color.lower(), alpha=0.7)
            ax.set_title(f'{class_name} - {color} Channel Distribution')
            ax.set_xlabel('Intensity')
            ax.set_ylabel('Density')
            ax.set_ylim(0, global_max)
        plt.tight_layout()
        plt.show()


def plot_filtered_images_by_entropy(filtered_image_paths, filtered_entropies, images_per_row=4):
    """Display a grid of images annotated with class, file name and entropy.

    Args:
        filtered_image_paths: Paths of the images to show.
        filtered_entropies: Entropy value per image (aligned with paths).
        images_per_row: Number of image columns in the grid.
    """
    num_images = len(filtered_image_paths)
    num_rows = (num_images + images_per_row - 1) // images_per_row  # ceil division
    fig, axes = plt.subplots(num_rows, images_per_row, figsize=(20, 5 * num_rows))
    # plt.subplots returns a bare Axes (not an array) for a 1x1 grid, so
    # .flatten() alone would crash; np.atleast_1d handles both cases.
    axes = np.atleast_1d(axes).flatten()

    for ax, (image_path, entropy) in zip(axes, zip(filtered_image_paths, filtered_entropies)):
        image = cv2.imread(image_path)
        if image is None:
            # Unreadable file: leave this slot blank instead of crashing
            # in cvtColor below.
            ax.axis('off')
            continue
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        class_name = os.path.basename(os.path.dirname(image_path))
        file_name = os.path.basename(image_path)
        ax.imshow(image)
        ax.set_title(f'{class_name}/{file_name}\nEntropy: {entropy:.2f}')
        ax.axis('off')

    # Hide any unused trailing axes in the last row.
    for ax in axes[num_images:]:
        ax.axis('off')

    plt.tight_layout()
    plt.show()
InĀ [57]:
image_paths = get_image_paths(process_data_set.UNPROC_DATASET_LOC)
InĀ [58]:
color_distributions, color_types, shapes, aspect_ratios, variances, unique_colors, entropies, image_paths_list = get_color_distributions(
    image_paths)
Running on 28 workers
Processing images: 100%|ā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆ| 6637/6637 [00:26<00:00, 248.23it/s]
Total processing time: 26.97 seconds

Image color summary¶

There seem to be no grayscale images and all images have 3 color channels.

InĀ [59]:
summary_table = summarize_image_types(image_paths, color_types)
summary_table
Out[59]:
Color Type Count
0 Color 6637

Dimensions¶

There seems to be a lot of variance between image sizes and aspect ratios. This is a significant concern because models like ResNet require fixed input sizes (224x224 in this case)

We are using FastAI's 'ImageDataLoaders' which handles the resizing and adds padding to samples which are not square.

InĀ [60]:
width_summary, height_summary, aspect_ratio_summary = summarize_dimensions(image_paths, shapes)

width_stats = width_summary["Width"].describe()
height_stats = height_summary["Height"].describe()
aspect_ratio_stats = aspect_ratio_summary["Aspect Ratio"].describe()

summary_df = pd.DataFrame({
    'Width': width_stats,
    'Height': height_stats,
    'Aspect Ratio': aspect_ratio_stats
})

display(summary_df)
Width Height Aspect Ratio
count 363.000000 505.000000 1156.000000
mean 709.586777 623.914851 1.325247
std 207.637554 183.902749 0.237855
min 259.000000 152.000000 0.561250
25% 572.500000 487.000000 1.226641
50% 702.000000 613.000000 1.336744
75% 797.500000 749.000000 1.462122
max 1280.000000 1024.000000 2.857143

Color Variance and Entropy¶

We'll further analyze the images using these metrics:

  • Average variance of color channels across all images:

    • Variance = 0: All pixels in the image have the same color.
    • High Variance: Indicates images with diverse color pixels.
  • Number of unique colors in each image

  • Entropy (shannon_entropy).

    • Scale: 0 to log2(N), where N is the number of possible pixel values (0 to 8 for 256 grayscale values).
      • Min Entropy = 0: Perfectly uniform image (single color).
      • High Entropy: Indicates images with a wide variety of colors and patterns.
InĀ [61]:
image_entropy_summary = summarize_color_metrics(image_paths, variances, unique_colors, entropies)
variance_summary = image_entropy_summary['variance'].describe()
unique_colors_summary = image_entropy_summary['unique_colors'].describe()
entropy_summary = image_entropy_summary['entropy'].describe()

summary_df = pd.DataFrame({
    'Variance': variance_summary,
    'Unique Color': unique_colors_summary,
    'Entropy': entropy_summary
})

display(summary_df)
Variance Unique Color Entropy
count 6637.000000 6637.000000 6637.000000
mean 3339.696198 9117.627844 7.575777
std 1040.674133 858.859522 0.289316
min 0.000000 1.000000 0.000000
25% 2629.536395 8884.000000 7.488475
50% 3218.788643 9367.000000 7.625352
75% 3928.063424 9661.000000 7.732022
max 11091.372221 9968.000000 7.979296

This information will be used to filter out the samples which might not be useful for image analysis

InĀ [62]:
image_entropy_summary = summarize_color_metrics(image_paths, variances, unique_colors, entropies)
image_entropy_summary['class'] = image_entropy_summary['image_path'].apply(
    lambda x: os.path.basename(os.path.dirname(x)))

g = sns.displot(
    image_entropy_summary, x="entropy", row="class", binwidth=0.1, height=3, aspect=2.5,
    facet_kws=dict(margin_titles=True)
).set(xlim=(6.5, None))
g.fig.suptitle('Entropy Distribution by Class', y=1.01)
Out[62]:
Text(0.5, 1.01, 'Entropy Distribution by Class')
No description has been provided for this image
InĀ [63]:
invalid_image_paths = []
Samples with Very Low Entropy¶

The images below have very low entropy (i.e. in the bottom 0.5th percentile).

We can see that while some images are actual mushrooms that were photographed against a single color background (probably in a studio etc.) the image with 0.0 entropy is not valid. Additionally, we can see that one of the images is not an actual mushroom but just a random pattern (unfortunately our color variance based approach is not particularly useful at identifying images as such unless the pattern is very simple)

InĀ [65]:
entropy_thresh = np.percentile(image_entropy_summary["entropy"], 0.5, axis=0)
filtered_image_paths = [path for path, entropy in zip(image_paths_list, entropies) if entropy < entropy_thresh]
filtered_entropies = [entropy for entropy in entropies if entropy < entropy_thresh]
plot_filtered_images_by_entropy(filtered_image_paths, filtered_entropies, images_per_row=5)
No description has been provided for this image
InĀ [66]:
invalid_image_paths.append("0051_rBIC-Uy9KzI.jpg")
invalid_image_paths.append("0127_1R8TZJseXgY.jpg")
High Entropy Images¶

These images tend to show very colorful mushrooms and have colorful varying forest backgrounds.

InĀ [67]:
top_n = 25
sorted_images_by_entropy = sorted(zip(image_paths_list, entropies), key=lambda x: x[1], reverse=True)[:top_n]
filtered_image_paths, filtered_entropies = zip(*sorted_images_by_entropy)
plot_filtered_images_by_entropy(filtered_image_paths, filtered_entropies, 5)
No description has been provided for this image

Color Channel Distribution by Class¶

These plots show the normalized intensity (0 - 255) distributions of each color channel by class. The Y axis shows the normalized frequency (density) relative to all color channels (scaled to the highest individual value for any channel).

The charts are made by generating a histogram for each image and normalizing it (normalization maintains the shape of the histogram, meaning the relative distribution of pixel intensities is preserved). All histograms in the class are then averaged.

As we would expect, green and red tend to be dominant in most images, which reflects the typical colors of the mushrooms and the forest-floor backgrounds they tend to be photographed against.

InĀ [68]:
plot_color_distributions(color_distributions)
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image

Data Sampling¶

Our next step is to split the dataset into test and training samples, aiming to maintain roughly a 15% test sample size. The sample ratios used in our model training:

  • Test: 15% of full dataset:
  • Validation: 20% of remaining samples (i.e. 17% of full dataset)
  • Remaining images used for training: 64%

We're using stratification to maintain a relatively representative distribution for all classes. The invalid images found during our analysis are also excluded.

InĀ [Ā ]:
importlib.reload(process_data_set)
process_data_set.stratified_split(
    process_data_set.UNPROC_DATASET_LOC, process_data_set.OUTPUT_NAME, 0.15, invalid_images=invalid_image_paths
)
InĀ [70]:
dataset_summary_df = process_data_set.verify_dataset(process_data_set.OUTPUT_NAME)
Verification Passed: No overlapping files between train and test sets.
InĀ [71]:
count_columns = ["Total Samples", "Training Samples", "Testing Samples"]

melted_df = dataset_summary_df.reset_index().melt(id_vars=["index"], value_vars=count_columns, var_name="Sample Type",
                                                  value_name="Count")

melted_df = melted_df.rename(columns={"index": "Class"})

plt.figure(figsize=(12, 6))
sns.barplot(x="Class", y="Count", hue="Sample Type", data=melted_df)
plt.title('Distribution of Samples per Class')
plt.xticks(rotation=45)
plt.show()
No description has been provided for this image
InĀ [7]:
dataset_summary_df
Out[7]:
Training Samples Training Proportion (%) Testing Samples Testing Proportion (%) Total Samples
Boletus 909 84.95% 161 15.05% 1070
Suillus 264 84.89% 47 15.11% 311
Cortinarius 709 85.01% 125 14.99% 834
Russula 972 85.04% 171 14.96% 1143
Agaricus 298 84.90% 53 15.10% 351
Amanita 636 85.03% 112 14.97% 748
Entoloma 309 84.89% 55 15.11% 364
Hygrocybe 268 85.08% 47 14.92% 315
Lactarius 1274 84.99% 225 15.01% 1499